{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "xvSGDbExff_I"
      },
      "source": [
        "# &#x1F4D1; **作业7- 基于BERT的文字问答**\n",
        "\n",
        "这是台大2022年机器学习课程作业7的样本代码\n",
        "\n",
        "这里有一些有用的链接：\n",
        "幻灯片:    [Link](https://docs.google.com/presentation/d/1H5ZONrb2LMOCixLY7D5_5-7LkIaXO6AGEaV2mRdTOMY/edit?usp=sharing)　Kaggle链接: [Link](https://www.kaggle.com/c/ml2022spring-hw7)　数据集下载: [Link](https://drive.google.com/uc?id=1AVgZvy3VFeg0fX-6WQJMHPVrx3A-M1kb)\n",
        "\n",
        "\n"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "WGOr_eS3wJJf"
      },
      "source": [
        "## &#x2728;任务描述\n",
        "\n",
        "- 中文抽取式问答\n",
        "    - 输入：段落 + 问题\n",
        "    - 输出：答案\n",
        "\n",
        "- 目标\n",
        "    - 学习如何使用transformers在下游任务上微调预训练模型"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "TJ1fSAJE2oaC"
      },
      "source": [
        "## &#x2728;下载数据"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YPrc4Eie9Yo5"
      },
      "outputs": [],
      "source": [
        "# 下载链接 1\n",
        "!gdown --id '1AVgZvy3VFeg0fX-6WQJMHPVrx3A-M1kb' --output hw7_data.zip\n",
        "\n",
        "# # 下载链接 2\n",
        "# # !gdown --id '1qwjbRjq481lHsnTrrF4OjKQnxzgoLEFR' --output hw7_data.zip\n",
        "\n",
        "# # 下载链接 3\n",
        "# # !gdown --id '1QXuWjNRZH6DscSd6QcRER0cnxmpZvijn' --output hw7_data.zip\n",
        "\n",
        "!unzip -o hw7_data.zip"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "TevOvhC03m0h"
      },
      "source": [
        "## &#x2728;安装Transforms所需的package\n",
        "\n",
        "Python包的说明文档:　https://huggingface.co/transformers/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tbxWFX_jpDom"
      },
      "outputs": [],
      "source": [
        "!pip install --no-dependencies transformers==4.5.0"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "8dKM4yCh4LI_"
      },
      "source": [
        "## &#x2728;导入所需package"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WOTHHtWJoahe"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "import numpy as np\n",
        "import random\n",
        "import torch\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "from transformers import AdamW, BertForQuestionAnswering, BertTokenizerFast\n",
        "\n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# 固定随机种植，增加可复现性\n",
        "def same_seeds(seed):\n",
        "\ttorch.manual_seed(seed)\n",
        "\tif torch.cuda.is_available():\n",
        "\t\ttorch.cuda.manual_seed(seed)\n",
        "\t\ttorch.cuda.manual_seed_all(seed)\n",
        "\tnp.random.seed(seed)\n",
        "\trandom.seed(seed)\n",
        "\ttorch.backends.cudnn.benchmark = False\n",
        "\ttorch.backends.cudnn.deterministic = True\n",
        "same_seeds(0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7pBtSZP1SKQO"
      },
      "outputs": [],
      "source": [
        "# 将 \"fp16_training \"改为 True，以支持自动混合精度训练 (fp16)\t\n",
        "fp16_training = False\n",
        "\n",
        "if fp16_training:\n",
        "    !pip install accelerate==0.2.0\n",
        "    from accelerate import Accelerator\n",
        "    accelerator = Accelerator(fp16=True)\n",
        "    device = accelerator.device\n",
        "\n",
        "# Python包说明文档:  https://huggingface.co/docs/accelerate/"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "2YgXHuVLp_6j"
      },
      "source": [
        "## &#x2728;加载预训练模型和标记符（Tokenizer）\n",
        "\n",
        "\n",
        "\n",
        "\n",
        " "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xyBCYGjAp3ym"
      },
      "outputs": [],
      "source": [
        "model = BertForQuestionAnswering.from_pretrained(\"bert-base-chinese\").to(device)\n",
        "tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-chinese\")"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "3Td-GTmk5OW4"
      },
      "source": [
        "## &#x2728;读取数据\n",
        "\n",
        "- 训练集: 31690 个问答对\n",
        "- 验证集: 4131 个问答对\n",
        "- 测试集: 4957 个问答对\n",
        "\n",
        "- {train/dev/test}_questions:\t字典列表，包含以下键：\n",
        "  - id (整数)\n",
        "  - paragraph_id (整数)\n",
        "  - question_text (字符)\n",
        "  - answer_text (字符)\n",
        "  - answer_start (整数)\n",
        "  - answer_end (整数)\n",
        "\n",
        "- {train/dev/test}_paragraphs: \n",
        "  - 字符串列表\n",
        "  - 问题中的段落 ID 对应于段落列表中的索引\n",
        "  - 一个段落可能被多个问题使用\n"
      ]
    },
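    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The record layout above can be sketched with a tiny hand-made example. The paragraph, the question and the helper `answer_from_offsets` are hypothetical (they are not taken from the homework files), and the sketch assumes `answer_end` is an inclusive character index, which is how the dataset code in this notebook treats it.\n",
        "\n",
        "```python\n",
        "# Hypothetical toy data that mimics the structure returned by read_data()\n",
        "paragraphs = ['BERT was released by Google in 2018.']\n",
        "questions = [{\n",
        "    'id': 0,\n",
        "    'paragraph_id': 0,   # index into the paragraphs list\n",
        "    'question_text': 'Who released BERT?',\n",
        "    'answer_text': 'Google',\n",
        "    'answer_start': 21,  # character offset of the first answer character\n",
        "    'answer_end': 26,    # character offset of the last answer character (assumed inclusive)\n",
        "}]\n",
        "\n",
        "def answer_from_offsets(question, paragraphs):\n",
        "    # Look up the paragraph by index, then slice with an inclusive end offset\n",
        "    paragraph = paragraphs[question['paragraph_id']]\n",
        "    return paragraph[question['answer_start']:question['answer_end'] + 1]\n",
        "\n",
        "recovered = answer_from_offsets(questions[0], paragraphs)\n",
        "print(recovered)\n",
        "```\n",
        "\n",
        "If the inclusive-offset assumption holds for the real files, the same check should reproduce `answer_text` for every training question."
      ]
    },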
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NvX7hlepogvu"
      },
      "outputs": [],
      "source": [
        "def read_data(file):\n",
        "    with open(file, 'r', encoding=\"utf-8\") as reader:\n",
        "        data = json.load(reader)\n",
        "    return data[\"questions\"], data[\"paragraphs\"]\n",
        "\n",
        "train_questions, train_paragraphs = read_data(\"hw7_train.json\")\n",
        "dev_questions, dev_paragraphs = read_data(\"hw7_dev.json\")\n",
        "test_questions, test_paragraphs = read_data(\"hw7_test.json\")"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "Fm0rpTHq0e4N"
      },
      "source": [
        "## &#x2728;标记符数据"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rTZ6B70Hoxie"
      },
      "outputs": [],
      "source": [
        "# 分别标记问题和段落\n",
        "# 将 \"add_special_tokens \"设为 False，因为在数据集 __getitem__ 中合并标记化的问题和段落时，将添加特殊标记。\n",
        "\n",
        "train_questions_tokenized = tokenizer([train_question[\"question_text\"] for train_question in train_questions], add_special_tokens=False)\n",
        "dev_questions_tokenized = tokenizer([dev_question[\"question_text\"] for dev_question in dev_questions], add_special_tokens=False)\n",
        "test_questions_tokenized = tokenizer([test_question[\"question_text\"] for test_question in test_questions], add_special_tokens=False) \n",
        "\n",
        "train_paragraphs_tokenized = tokenizer(train_paragraphs, add_special_tokens=False)\n",
        "dev_paragraphs_tokenized = tokenizer(dev_paragraphs, add_special_tokens=False)\n",
        "test_paragraphs_tokenized = tokenizer(test_paragraphs, add_special_tokens=False)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "Ws8c8_4d5UCI"
      },
      "source": [
        "## &#x2728;数据集和数据加载器"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xjooag-Swnuh"
      },
      "outputs": [],
      "source": [
        "class QA_Dataset(Dataset):\n",
        "    def __init__(self, split, questions, tokenized_questions, tokenized_paragraphs):\n",
        "        self.split = split\n",
        "        self.questions = questions\n",
        "        self.tokenized_questions = tokenized_questions\n",
        "        self.tokenized_paragraphs = tokenized_paragraphs\n",
        "        self.max_question_len = 40\n",
        "        self.max_paragraph_len = 150\n",
        "        \n",
        "        ##### TODO: 更改 doc_stride 的值 #####\n",
        "        self.doc_stride = 150\n",
        "\n",
        "        # 输入序列长度 = [CLS] + 问题 + [SEP] + 段落 + [SEP]\n",
        "        self.max_seq_len = 1 + self.max_question_len + 1 + self.max_paragraph_len + 1\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.questions)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        question = self.questions[idx]\n",
        "        tokenized_question = self.tokenized_questions[idx]\n",
        "        tokenized_paragraph = self.tokenized_paragraphs[question[\"paragraph_id\"]]\n",
        "\n",
        "\n",
        "        if self.split == \"train\":\n",
        "            # 将答案在段落文字中的起始/结束位置转换为标记化段落中的起始/结束位置  \n",
        "            answer_start_token = tokenized_paragraph.char_to_token(question[\"answer_start\"])\n",
        "            answer_end_token = tokenized_paragraph.char_to_token(question[\"answer_end\"])\n",
        "\n",
        "            # 将段落中包含答案的部分切成片段，就得到了一个窗口\n",
        "            mid = (answer_start_token + answer_end_token) // 2\n",
        "            paragraph_start = max(0, min(mid - self.max_paragraph_len // 2, len(tokenized_paragraph) - self.max_paragraph_len))\n",
        "            paragraph_end = paragraph_start + self.max_paragraph_len\n",
        "            \n",
        "            # 切分问题/段落并添加特殊标记（101：CLS，102：SEP）\n",
        "            input_ids_question = [101] + tokenized_question.ids[:self.max_question_len] + [102] \n",
        "            input_ids_paragraph = tokenized_paragraph.ids[paragraph_start : paragraph_end] + [102]\t\t\n",
        "            \n",
        "            # 将标记化段落中答案的起始/结束位置转换为窗口中的起始/结束位置  \n",
        "            answer_start_token += len(input_ids_question) - paragraph_start\n",
        "            answer_end_token += len(input_ids_question) - paragraph_start\n",
        "            \n",
        "            # 填充序列并获得模型输入\n",
        "            input_ids, token_type_ids, attention_mask = self.padding(input_ids_question, input_ids_paragraph)\n",
        "            return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), answer_start_token, answer_end_token\n",
        "\n",
        "        # 验证/测试\n",
        "        else:\n",
        "            input_ids_list, token_type_ids_list, attention_mask_list = [], [], []\n",
        "            \n",
        "            # 段落被分割成几个窗口，每个窗口的起始位置用步长 \"doc_stride \"隔开\n",
        "            for i in range(0, len(tokenized_paragraph), self.doc_stride):\n",
        "                \n",
        "                # 切分问题/段落并添加特殊标记（101：CLS，102：SEP）\n",
        "                input_ids_question = [101] + tokenized_question.ids[:self.max_question_len] + [102]\n",
        "                input_ids_paragraph = tokenized_paragraph.ids[i : i + self.max_paragraph_len] + [102]\n",
        "                \n",
        "                # 填充序列并获得模型输入\n",
        "                input_ids, token_type_ids, attention_mask = self.padding(input_ids_question, input_ids_paragraph)\n",
        "                \n",
        "                input_ids_list.append(input_ids)\n",
        "                token_type_ids_list.append(token_type_ids)\n",
        "                attention_mask_list.append(attention_mask)\n",
        "            \n",
        "            return torch.tensor(input_ids_list), torch.tensor(token_type_ids_list), torch.tensor(attention_mask_list)\n",
        "\n",
        "    def padding(self, input_ids_question, input_ids_paragraph):\n",
        "        # 如果序列长度小于 max_seq_len，则为零\n",
        "        padding_len = self.max_seq_len - len(input_ids_question) - len(input_ids_paragraph)\n",
        "        # 词汇表中输入序列标记的索引\n",
        "        input_ids = input_ids_question + input_ids_paragraph + [0] * padding_len\n",
        "        # 表示输入第一和第二部分的分段标记符号索引。索引在 [0, 1] 中选择\n",
        "        token_type_ids = [0] * len(input_ids_question) + [1] * len(input_ids_paragraph) + [0] * padding_len\n",
        "        # 屏蔽，以避免对填充标记索引执行关注。屏蔽值在 [0, 1] 中选择\n",
        "        attention_mask = [1] * (len(input_ids_question) + len(input_ids_paragraph)) + [0] * padding_len\n",
        "        \n",
        "        return input_ids, token_type_ids, attention_mask\n",
        "\n",
        "train_set = QA_Dataset(\"train\", train_questions, train_questions_tokenized, train_paragraphs_tokenized)\n",
        "dev_set = QA_Dataset(\"dev\", dev_questions, dev_questions_tokenized, dev_paragraphs_tokenized)\n",
        "test_set = QA_Dataset(\"test\", test_questions, test_questions_tokenized, test_paragraphs_tokenized)\n",
        "\n",
        "train_batch_size = 32\n",
        "\n",
        "# 注意：请勿更改 dev_loader / test_loader 的批次大小！\n",
        "# 虽然批次大小=1，但它实际上是由同一质量保证对的多个窗口组成的批次\n",
        "train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True, pin_memory=True)\n",
        "dev_loader = DataLoader(dev_set, batch_size=1, shuffle=False, pin_memory=True)\n",
        "test_loader = DataLoader(test_set, batch_size=1, shuffle=False, pin_memory=True)"
      ]
    },
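    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The effect of `doc_stride` on the dev/test windows can be sketched in plain Python. The helpers `windows` and `is_covered` and the numbers used (a 400-token paragraph, an answer spanning tokens 145 to 155) are hypothetical; only the `range(0, length, doc_stride)` iteration mirrors `QA_Dataset.__getitem__`.\n",
        "\n",
        "```python\n",
        "def windows(paragraph_len, max_paragraph_len, doc_stride):\n",
        "    # Mirror the dev/test branch of __getitem__: one window per start offset\n",
        "    return [(start, min(start + max_paragraph_len, paragraph_len))\n",
        "            for start in range(0, paragraph_len, doc_stride)]\n",
        "\n",
        "def is_covered(spans, answer_start, answer_end):\n",
        "    # True if some window contains the whole (inclusive) answer token span\n",
        "    return any(start <= answer_start and answer_end < end for start, end in spans)\n",
        "\n",
        "# Hypothetical numbers: a 400-token paragraph whose answer crosses token 150\n",
        "for doc_stride in (150, 75):\n",
        "    spans = windows(400, 150, doc_stride)\n",
        "    print(doc_stride, len(spans), is_covered(spans, 145, 155))\n",
        "```\n",
        "\n",
        "With a stride equal to the window length the windows do not overlap, so an answer that straddles a boundary is never fully inside one window. A smaller stride produces overlapping windows at the cost of more forward passes per question."
      ]
    },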
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "5_H1kqhR8CdM"
      },
      "source": [
        "## &#x2728;评估功能"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SqeA3PLPxOHu"
      },
      "outputs": [],
      "source": [
        "def evaluate(data, output):\n",
        "    \n",
        "    answer = ''\n",
        "    max_prob = float('-inf')\n",
        "    num_of_windows = data[0].shape[1]\n",
        "    \n",
        "    for k in range(num_of_windows):\n",
        "        # 通过选择最可能的起始位置/结束位置来获取答案\n",
        "        start_prob, start_index = torch.max(output.start_logits[k], dim=0)\n",
        "        end_prob, end_index = torch.max(output.end_logits[k], dim=0)\n",
        "        \n",
        "        # 答案的概率计算为起始概率和结束概率之和\n",
        "        prob = start_prob + end_prob\n",
        "        \n",
        "        # 如果计算出的概率大于之前的窗口，则替换答案\n",
        "        if prob > max_prob:\n",
        "            max_prob = prob\n",
        "            # 将令牌转换为字符（例如 [1920, 7032] --> \"大金\"）。\n",
        "            answer = tokenizer.decode(data[0][0][k][start_index : end_index + 1])\n",
        "    \n",
        "    # 删除答案中的空格（如 \"大金\"-->\"大金\"）。\n",
        "    return answer.replace(' ','')"
      ]
    },
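    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`evaluate` picks the start and end positions independently, so nothing prevents the chosen end from preceding the start, in which case the decoded answer is empty. One possible refinement, which is not part of the sample code, is to score only spans with start <= end and a bounded length. The sketch below uses plain Python lists in place of logits; `best_valid_span` and `max_answer_len` are hypothetical names.\n",
        "\n",
        "```python\n",
        "def best_valid_span(start_logits, end_logits, max_answer_len=30):\n",
        "    # Exhaustively score every span with start <= end and length <= max_answer_len\n",
        "    best_score, best_span = float('-inf'), None\n",
        "    for start, start_score in enumerate(start_logits):\n",
        "        last = min(start + max_answer_len, len(end_logits))\n",
        "        for end in range(start, last):\n",
        "            score = start_score + end_logits[end]\n",
        "            if score > best_score:\n",
        "                best_score, best_span = score, (start, end)\n",
        "    return best_span, best_score\n",
        "\n",
        "# Toy logits: the independent argmax would pick start=3 and end=1 (an invalid span)\n",
        "start_logits = [0.1, 0.2, 2.0, 5.0]\n",
        "end_logits = [0.0, 4.0, 1.0, 3.0]\n",
        "span, score = best_valid_span(start_logits, end_logits)\n",
        "print(span, score)\n",
        "```\n",
        "\n",
        "The double loop is quadratic in the window length, which is acceptable for windows of a few hundred tokens; the same idea could be applied per window before comparing scores across windows."
      ]
    },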
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "rzHQit6eMnKG"
      },
      "source": [
        "## &#x2728;训练"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3Q-B6ka7xoCM"
      },
      "outputs": [],
      "source": [
        "num_epoch = 1\n",
        "validation = True\n",
        "logging_step = 100\n",
        "learning_rate = 1e-4\n",
        "optimizer = AdamW(model.parameters(), lr=learning_rate)\n",
        "\n",
        "if fp16_training:\n",
        "    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader) \n",
        "\n",
        "model.train()\n",
        "\n",
        "print(\"Start Training ...\")\n",
        "\n",
        "for epoch in range(num_epoch):\n",
        "    step = 1\n",
        "    train_loss = train_acc = 0\n",
        "    \n",
        "    for data in tqdm(train_loader):\t\n",
        "        # 将所有数据加载到 GPU\n",
        "        data = [i.to(device) for i in data]\n",
        "        \n",
        "        # 模型输入：input_ids、token_type_ids、attention_mask、start_positions、end_positions（注意：只有 \"input_ids \"是必填项）\n",
        "        # 模型输出：start_logits、end_logits、loss（提供 start_positions/end_positions 时返回）  \n",
        "        output = model(input_ids=data[0], token_type_ids=data[1], attention_mask=data[2], start_positions=data[3], end_positions=data[4])\n",
        "\n",
        "        # 选择最可能的起始位置/结束位置\n",
        "        start_index = torch.argmax(output.start_logits, dim=1)\n",
        "        end_index = torch.argmax(output.end_logits, dim=1)\n",
        "        \n",
        "        # 只有当 start_index 和 end_index 都正确时，预测才是正确的\n",
        "        train_acc += ((start_index == data[3]) & (end_index == data[4])).float().mean()\n",
        "        train_loss += output.loss\n",
        "        \n",
        "        if fp16_training:\n",
        "            accelerator.backward(output.loss)\n",
        "        else:\n",
        "            output.loss.backward()\n",
        "        \n",
        "        optimizer.step()\n",
        "        optimizer.zero_grad()\n",
        "        step += 1\n",
        "\n",
        "        ##### TODO: 应用线性学习率衰减 #####\n",
        "        # lr_this_step = learning_rate * (1 - (epoch * len(train_loader) + step) / (num_epoch * len(train_loader)))\n",
        "        # for param_group in optimizer.param_groups:\n",
        "        #     param_group['lr'] = lr_this_step\n",
        "        \n",
        "        # 打印过去记录步骤中的训练损耗和准确性\n",
        "        if step % logging_step == 0:\n",
        "            print(f\"Epoch {epoch + 1} | Step {step} | loss = {train_loss.item() / logging_step:.3f}, acc = {train_acc / logging_step:.3f}\")\n",
        "            train_loss = train_acc = 0\n",
        "\n",
        "    if validation:\n",
        "        print(\"Evaluating Dev Set ...\")\n",
        "        model.eval()\n",
        "        with torch.no_grad():\n",
        "            dev_acc = 0\n",
        "            for i, data in enumerate(tqdm(dev_loader)):\n",
        "                output = model(input_ids=data[0].squeeze(dim=0).to(device), token_type_ids=data[1].squeeze(dim=0).to(device),\n",
        "                       attention_mask=data[2].squeeze(dim=0).to(device))\n",
        "                # 只有当答案文本完全匹配时，预测才是正确的\n",
        "                dev_acc += evaluate(data, output) == dev_questions[i][\"answer_text\"]\n",
        "            print(f\"Validation | Epoch {epoch + 1} | acc = {dev_acc / len(dev_loader):.3f}\")\n",
        "        model.train()\n",
        "\n",
        "# 保存模型及其配置文件到目录 \"saved_model \n",
        "# 即在 \"saved_model \"目录下有两个文件： \"pytorch_model.bin \"和 \"config.json\"。\n",
        "# 可使用「model = BertForQuestionAnswering.from_pretrained(\"saved_model\")」重新加载保存的模型。\n",
        "\n",
        "# print(\"Saving Model ...\")\n",
        "# model_save_dir = \"saved_model\" \n",
        "# model.save_pretrained(model_save_dir)"
      ]
    },
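    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The linear learning rate decay named in the TODO can be sketched as a pure function of the step count. `linear_decay_lr` is a hypothetical helper, not part of the sample code; it assumes the rate falls linearly from `base_lr` at step 0 to 0 at `total_steps` and is clamped at 0 afterwards. The value of `total_steps` below is made up for illustration.\n",
        "\n",
        "```python\n",
        "def linear_decay_lr(base_lr, current_step, total_steps):\n",
        "    # Fraction of training that is still ahead, clamped so it never goes negative\n",
        "    remaining = max(0.0, 1.0 - current_step / total_steps)\n",
        "    return base_lr * remaining\n",
        "\n",
        "base_lr = 1e-4\n",
        "total_steps = 1000  # hypothetical stand-in for num_epoch * len(train_loader)\n",
        "schedule = [linear_decay_lr(base_lr, step, total_steps) for step in (0, 500, 1000)]\n",
        "print(schedule)\n",
        "```\n",
        "\n",
        "In the training loop the returned value would be written to each entry of `optimizer.param_groups` after `optimizer.step()`. The transformers library also provides `get_linear_schedule_with_warmup`, which could be used instead of a hand-written schedule."
      ]
    },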
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "kMmdLOKBMsdE"
      },
      "source": [
        "## &#x2728;测试"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U5scNKC9xz0C"
      },
      "outputs": [],
      "source": [
        "print(\"Evaluating Test Set ...\")\n",
        "\n",
        "result = []\n",
        "\n",
        "model.eval()\n",
        "with torch.no_grad():\n",
        "    for data in tqdm(test_loader):\n",
        "        output = model(input_ids=data[0].squeeze(dim=0).to(device), token_type_ids=data[1].squeeze(dim=0).to(device),\n",
        "                       attention_mask=data[2].squeeze(dim=0).to(device))\n",
        "        result.append(evaluate(data, output))\n",
        "\n",
        "result_file = \"result.csv\"\n",
        "with open(result_file, 'w') as f:\t\n",
        "\t  f.write(\"ID,Answer\\n\")\n",
        "\t  for i, test_question in enumerate(test_questions):\n",
        "        # 用空字符串替换答案中的逗号（因为 csv 用逗号分隔）\n",
        "        # 以相同方式处理 kaggle 中的答案\n",
        "\t\t    f.write(f\"{test_question['id']},{result[i].replace(',','')}\\n\")\n",
        "\n",
        "print(f\"Completed! Result is in {result_file}\")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "ML2022Spring - HW7.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
